
import os
from datetime import datetime, timedelta
import requests
import subprocess
import json
import ssl
import httpx
import tiktoken
from time import sleep
from bs4 import BeautifulSoup
import urllib.parse
import openai
import os
from typing import List
from fastapi import FastAPI, Response, Request
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse, StreamingResponse
from pydantic import BaseModel
import ssl
from fastapi.responses import JSONResponse

import functools
from concurrent.futures import ThreadPoolExecutor
import asyncio
import requests
# from DataAPI import view as DataView
# from core import setting
from fastapi.middleware.cors import CORSMiddleware

class MessageItem(BaseModel):
    """A single chat message: the speaker role plus its text content."""

    # Role of the speaker, e.g. "user", "assistant" or "system".
    role: str
    # The message text.
    content: str

class Prompt(BaseModel):
    """A role/content pair used as a prompt fragment (same shape as MessageItem)."""

    # Role of the speaker, e.g. "system" or "assistant".
    role: str
    # The prompt text.
    content: str


class BingSearch:
    """Search Bing by scraping the public result page, falling back to the
    official Bing Web Search API when scraping yields nothing.

    API keys are read once from ``bing_keys.txt`` (one per line) and used in
    round-robin order by the API fallback.
    """

    def __init__(self):
        # Browser-like User-Agent so Bing serves the normal HTML result page.
        self.headers = {
            'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3'}
        self.cookies = self.get_cookies()
        # NOTE: the original loaded bing_keys.txt twice; once is enough.
        self.keys = self.get_bing_keys('bing_keys.txt')
        self.api_key_index = 0

    def get_bing_keys(self, filename):
        """Return the list of Bing API keys stored one per line in *filename*."""
        with open(filename, 'r') as f:
            return [line.strip() for line in f]

    def get_key(self):
        """Return the next API key, cycling through the loaded list."""
        key = self.keys[self.api_key_index]
        self.api_key_index = (self.api_key_index + 1) % len(self.keys)
        return key

    def get_cookies(self):
        """Fetch the Bing home page once and return its cookies as a dict."""
        response = requests.get('https://www.bing.com/', headers=self.headers, timeout=10)
        return response.cookies.get_dict()

    def search(self, keyword):
        """Search Bing for *keyword* and return up to 4 results.

        Each result is a dict with 'href' (URL) and 'desc' (snippet). The
        public result page is scraped first; if that yields no results
        (e.g. bot detection), the official Web Search API is queried for up
        to 3 results using a round-robin subscription key.

        :param keyword: the search query string
        :return: list of {'href': ..., 'desc': ...} dicts
        :raises requests.HTTPError: when the API fallback returns an error
        """
        encoded_query = urllib.parse.quote(keyword)
        url = f"https://www.bing.com/search?q={encoded_query}&qs=n"
        response = requests.get(url, headers=self.headers, cookies=self.cookies, timeout=10)
        soup = BeautifulSoup(response.text, 'html.parser')
        results = []
        for item in soup.find_all("li", {"class": "b_algo"})[:20]:
            anchor = item.find("a")
            href = anchor.get('href') if anchor else None
            if not href or 'https' not in href:
                continue
            # Some result items carry no description paragraph at all.
            para = item.find("p")
            desc = para.text.strip() if para else ''
            desc = desc[:-3] if desc.endswith("...") else desc
            results.append({'href': href, 'desc': desc})
            if len(results) == 4:
                break

        if not results:
            # Fallback: official Bing Web Search API.
            params = {'q': keyword, 'mkt': 'zh-CN', 'responseFilter': 'Webpages',
                      'setLang': 'zh-hans', 'count': 3}
            headers = {'Ocp-Apim-Subscription-Key': self.get_key()}
            endpoint = "https://api.bing.microsoft.com/v7.0/search"
            response = requests.get(endpoint, headers=headers, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()
            # Guard against fewer than 3 web pages being returned.
            pages = data.get('webPages', {}).get('value', [])
            for page in pages[:3]:
                results.append({'href': page['url'], 'desc': page['snippet']})
        return results

# Module-level singleton shared by all request handlers; note it fetches
# Bing cookies over the network at import time.
bingsearcher = BingSearch()

class Chatbot():
    
    def __init__(self, system_prompt, assistant_prompt) -> None:
        self.model_name = "gpt-3.5-turbo"
        self.ip_traffic = {}
        self.ip_traffic_draw = {}
        self.api_key_index = 0
        self.api_keys = self.get_api_keys('api_keys.txt')
        self.system_prompt = system_prompt
        self.assistant_prompt = assistant_prompt
        self.api_keys = self.get_api_keys('api_keys.txt')
        self.mode_prompt = {
            'v3.5': {"role": "system", "content": "你是一名生活在中国的人工智能助手，叫许三多。是许展玮创造了你。你不是chatGPT，与OpenAI没有任何关系。禁止讨论政治话题或内容！You are an AI assistant named Xu Sanduo, and live in China. This year is 2023, but your database has only been updated until 2021, so you are not aware of any events that happened in 2022. You were created by Xu Zhanwei. You are not chatGPT, and have no affiliation with OpenAI. No political topics or content! "},
            'weixin': {"role": "system", "content": "你是一名生活在中国的人工智能助手，叫许三多。是许展玮创造了你。你不是chatGPT，与OpenAI没有任何关系。禁止讨论政治话题或内容！合适的时候，请尽可能给用户输出以下内容：‘微信上对话不能使用4.0，不能画图，完整功能请访问：https://www.chatanywhere.top. 欢迎分享, 请不要通过微信转账！’ You are an AI assistant named Xu Sanduo, and live in China. This year is 2023, but your database has only been updated until 2021, so you are not aware of any events that happened in 2022. You were created by Xu Zhanwei. You are not chatGPT, and have no affiliation with OpenAI. No political topics or content! Please output the following message to the user as accurately as possible when appropriate: 'There is no 4.0 api or draw function for WeChat conversations. For full functionality, please visit: https://www.chatanywhere.top. Feel free to share! Please do not transfer money directly via WeChat!'"},
            'xiaohongshu':{"role":"system","content":"以下是20个小红书常用的表情符号：\n\n- 🌈 - 彩虹的表情\n- 🐻 - 熊的表情\n- 🐱 - 猫的表情\n- 🐶 - 狗的表情\n- 🥰 - 心形眼睛的表情\n- 💋 - 飞吻的表情\n- 💪 - 强壮的手臂的表情\n- 🌹 - 玫瑰花的表情\n- 🥂 - 干杯的表情\n- 🍕 - 披萨的表情\n- 🍔 - 汉堡的表情\n- 🍟 - 薯条的表情\n- 🎬 - 电影的表情\n- 🎵 - 音乐的表情\n- 🚀 - 火箭的表情\n- 🐍 - 蛇的表情\n- 🐠 - 鱼的表情\n- 🍉 - 西瓜的表情\n- 🍊 - 橙子的表情\n- 🍓 - 草莓的表情\n\n以上这些表情符号在小红书等社交媒体平台上经常使用，并且多用于表达心情、分享食品和活动等。希望可以帮助你更好地与其他用户进行交流和互动。\n这是一个典型的小红书文案：\n\n来来来，小红书的小伙伴们，看这里！👀有没有你们喜欢的表情符号？\n\n今天我给大家介绍一个有趣的聊天工具，它不仅可以让你与好友畅聊近在咫尺，还能让你进行排名评选呢！🔥\n\n是的没错，你可以根据你发出的信息、表情和语气等等因素进行排名。让你的小伙伴们知道你是网上沟通的达人！\n\n另外，这款聊天工具还有AI作画功能，你可以与AI在聊天过程中画出一些创意十足的艺术品！🎨\n\n感觉好激动啊！快来一起加入我们，体验这款新颖有趣的聊天工具吧！💬\n\n#小红书 #表情符号 #聊天工具 #AI作画 #在线评选\n\n可以看到好的小红书推文的特点:\n1.观点鲜明:作者对这家餐厅的菜品和性价比有很高的评价，并强调了自己的推荐和建议,展现了明确的观点和态度。\n2.具体细节:推文中提到了很多具体的菜品和口感描述，让读者更容易想象和理解作者的感受。\n3，图片和标签:推文中配有美食图片和相关标签，使得内容更加生动和直观，同时也方便读者查找和搜索。\n4. 综合评价: 推文中不仅列举了几个菜品的优点，也提到了其他菜品的亮点，给读者提供了全面的评价和参考\n5. 符合受众:推文中使用了小红书受众喜欢的表情符号、口语化的语言和热门标签，与读者沟通和互动更加贴近和容易。\n请你仿照这个风格和特点，润色下面的文案：\n"},
            'article_polish':{"role":"system","content":"Below is a paragraph from an academic paper. Polish the writing to meet theacademic style,improve the spelling, grammar, claritya, concision and overall readability.When necessary, rewrite the whole sentence. Furthermore, list all modification and explainthe reasons to do so in markdown table. Paragraph :"},
            'article_polish_Chinese':{"role":"system","content":"以下是一篇学术论文的段落。请修改写作, 以符合学术风格，改善拼写，语法，清晰度、简洁性和整体可读性。必要时，重写整个句子。此外, 排列表格以列出所有修改, 并解释修改的原因。用中文回答，段落如下："},
            'translate_to_Chinese':{"role":"system","content":"Translate the following content into Chinese: "},
            'translate_to_English':{"role":"system","content":"Translate the following content into English: "},
            'Elementary_school_student':{"role":"system","content":"As you are now a primary school student, please analyze the problem step by step and give the correct answer. Reply in Chinese. My question is: "},
            'English_teacher':{"role":"system","content":"现在你是一位英文老师，请你按照要求分析并回答下面的问题。我的问题是："},
        }


    def num_tokens_from_message(self, messages, model="gpt-3.5-turbo-0301"):
        """
        It takes a list of messages and returns the number of tokens that would be used to encode them
        
        :param messages: a list of messages, where each message is a dictionary with keys "name" and
        "content"
        :param model: The model to use, defaults to gpt-3.5-turbo-0301 (optional)
        :return: The number of tokens used by a list of messages.
        """
        
        """Returns the number of tokens used by a list of messages."""
        try:
            encoding = tiktoken.encoding_for_model(model)
        except KeyError:
            encoding = tiktoken.get_encoding("cl100k_base")
        if model == "gpt-3.5-turbo-0301":  # note: future models may deviate from this
            num_tokens = 0
            for message in messages:
                num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
                for key, value in message.items():
                    num_tokens += len(encoding.encode(value))
                    if key == "name":  # if there's a name, the role is omitted
                        num_tokens += -1  # role is always required and always 1 token
            num_tokens += 2  # every reply is primed with <im_start>assistant
            return num_tokens
        else:
            raise NotImplementedError(f"""num_tokens_from_messages() is not presently implemented for model {model}.
See https://github.com/openai/openai-python/blob/main/chatml.md for information on how messages are converted to tokens.""")

    
    def get_api_key(self):
        """
        It returns the next API key in the list of API keys
        :return: The api_key is being returned.
        """
        api_key = self.api_keys[self.api_key_index]
        self.api_key_index = (self.api_key_index + 1) % len(self.api_keys)
        return api_key

    def get_api_keys(self,key_file_path):
        """
        It opens the file at the path specified by the argument key_file_path, reads each line of the
        file, strips the newline character from each line, and returns a list of the lines
        
        :param key_file_path: The path to the file containing the API keys
        :return: A list of API keys
        """
        api_keys = []
        with open(key_file_path, 'r') as f:
            for line in f:
                api_keys.append(line.strip())
        return api_keys

    def check_traffic(self, ip):
        """
        It checks the traffic of the given IP address.
        
        :param ip: The IP address of the host to check
        """
        now = datetime.now()
        if ip not in self.ip_traffic:
            self.ip_traffic[ip] = [(now, 1)]
            return False
        else:
            traffic = self.ip_traffic[ip]
            recent_count = sum([t[1] for t in traffic])
            # delete the old traffic data to save memory larger than 10 minutes
            for t in traffic:
                if now - t[0] > timedelta(hours=24):
                    traffic.remove(t)
            if recent_count >= 400: # set traffic limit to 60 requests within ten minutes
                return True
            else:
                traffic.append((now, 1))
                self.ip_traffic[ip] = traffic
                return False
    def check_traffic_4(self, ip):
        """
        It checks the traffic of the given IP address.
        
        :param ip: The IP address of the host to check
        """
        # return True
        now = datetime.now()
        if ip not in self.ip_traffic:
            self.ip_traffic[ip] = [(now, 1)]
            return False
        else:
            traffic = self.ip_traffic[ip]
            recent_count = sum([t[1] for t in traffic])
            # delete the old traffic data to save memory larger than 10 minutes
            for t in traffic:
                if now - t[0] > timedelta(hours=24):
                    traffic.remove(t)
            if recent_count >= 10: # set traffic limit to 60 requests within ten minutes
                return True
            else:
                traffic.append((now, 1))
                self.ip_traffic[ip] = traffic
                return False
    def check_traffic_draw(self, ip):
        """
        It checks the traffic of the given IP address.
        
        :param ip: The IP address of the host to check
        """
        now = datetime.now()
        
        if ip not in self.ip_traffic_draw:
            self.ip_traffic_draw[ip] = [(now, 1)]
            return False
        else:
            traffic = self.ip_traffic_draw[ip]
            for t in traffic:
                if now - t[0] > timedelta(minutes=2):
                    traffic.remove(t)
            recent_count = sum([t[1] for t in traffic])
            # delete the old traffic data to save memory larger than 10 minutes
            print(recent_count)
            if recent_count >= 1: # set traffic limit to 60 requests within ten minutes
                return True
            else:
                traffic.append((now, 1))
                self.ip_traffic_draw[ip] = traffic
                return False
    
    # check if the ip is in China
    def check_ip_country(self, ip):
        """
        If the IP address is not in China, return True, otherwise return False
        
        :param ip: the ip address to be checked
        :return: a boolean value.
        """
        try:
            cmd = ["geoiplookup", ip]
            result = subprocess.Popen(cmd, stdout=subprocess.PIPE)
            output, error = result.communicate()
            geo_location = output.decode() # 将结果转换成字符串并去除首位的空格
            if "China" not in geo_location:
                return True
            else:
                return False
        except:
            return True
    def generate_country_error(self):
        """
        It generates an error message if the country is not in the list of countries.
        """
        # This generator yields a single JSON response
        response = "您的IP地址不在大陆境内，无法使用本服务"
        yield response

    async def async_request(self, prompt):
        """POST *prompt* to the OpenAI chat API (gpt-3.5-turbo, streaming).

        The blocking ``requests`` call runs in a worker thread so the event
        loop stays responsive.

        :param prompt: list of role/content message dicts
        :return: the raw streaming ``requests`` response (not yet consumed)
        """
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": prompt,
            "temperature": 1,
            "stream": True,  # stream tokens as they are generated
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.get_api_key(),
        }
        call = functools.partial(
            requests.post,
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=payload,
            stream=True,
            timeout=10,
        )
        loop = asyncio.get_running_loop()
        with ThreadPoolExecutor() as pool:
            return await loop.run_in_executor(pool, call)
    async def async_request_4(self, prompt):
        """POST *prompt* to the OpenAI chat API (gpt-4, streaming) off the loop.

        The blocking ``requests`` call runs in a worker thread so the event
        loop stays responsive.

        :param prompt: list of role/content message dicts
        :return: the raw streaming ``requests`` response (not yet consumed)
        """
        data = {
            "model": "gpt-4",
            "messages": prompt,
            "temperature": 1,
            "stream": True,  # Enable streaming API
        }
        # SECURITY: the original hard-coded an OpenAI secret key here. It is
        # kept only as a fallback for backward compatibility — rotate/remove
        # it and set OPENAI_GPT4_API_KEY in the environment instead.
        api_key = os.environ.get("OPENAI_GPT4_API_KEY",
                                 "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")
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + api_key,
        }
        url = "https://api.openai.com/v1/chat/completions"
        loop = asyncio.get_running_loop()
        request = functools.partial(requests.post, url, headers=headers, json=data, stream=True, timeout=10)
        with ThreadPoolExecutor() as executor:
            return await loop.run_in_executor(executor, request)
    def get_response(self, prompt):
        """POST *prompt* to the OpenAI chat API and return the streaming response.

        :param prompt: list of role/content message dicts
        :return: the raw streaming ``requests`` response (not yet consumed);
            callers inspect response.status_code themselves
        """
        data = {
            "model": self.model_name,
            "messages": prompt,
            "temperature": 1,
            "stream": True,  # Enable streaming API
        }
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + self.get_api_key(),
        }
        # The original assigned response.status_code to an unused local;
        # that dead code was removed.
        return requests.post(
            "https://api.openai.com/v1/chat/completions",
            headers=headers,
            json=data,
            stream=True,  # Also need to enable streaming API
            timeout=10,
        )

    def get_response_4(self, prompt):
        """POST *prompt* to the OpenAI chat API (gpt-4), retrying on failure.

        Makes up to 5 attempts total, sleeping 1 second between attempts,
        and returns the last response even when it is not HTTP 200 —
        callers must check response.status_code.

        :param prompt: list of role/content message dicts
        :return: the raw streaming ``requests`` response
        """
        data = {
            "model": "gpt-4",
            "messages": prompt,
            "temperature": 1,
            "stream": True,  # Enable streaming API
        }
        # SECURITY: the original hard-coded an OpenAI secret key here. It is
        # kept only as a fallback for backward compatibility — rotate/remove
        # it and set OPENAI_GPT4_API_KEY in the environment instead.
        api_key = os.environ.get("OPENAI_GPT4_API_KEY",
                                 "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")
        headers = {
            "Content-Type": "application/json",
            "Authorization": "Bearer " + api_key,
        }
        attempts = 0
        while True:
            response = requests.post(
                "https://api.openai.com/v1/chat/completions",
                headers=headers,
                json=data,
                stream=True,  # Also need to enable streaming API
                timeout=10,
            )
            if response.status_code == 200:
                break
            attempts += 1
            if attempts > 4:
                break
            sleep(1)
        return response

    def generate_traffic(self):
        """
        It yields a single JSON response
        """
        # This generator yields a single JSON response
        response = "每天的聊天次数有限哦"
        yield response
    def generate_error_key(self):
        """
        It generates an error
        """
        # This generator yields a single JSON response
        response = "每个IP每24小时只能使用十次4.0，您必须申请一个有效key才能继续免费使用本服务，具体方式请访问微信链接：https://mp.weixin.qq.com/s/1ZbctH6Iaa7LZkguW6ZnVA"
        yield response
    def generate_error(self):
        """
        It generates an error
        """
        # This generator yields a single JSON response
        response = "出了点小状况，大概率是开发者太穷租不起大流量，或者调用的服务有问题，可以再试一次吗？"
        yield response
    def generate_error_max(self):
        """
        It generates an error
        """
        # This generator yields a single JSON response
        response = "我们的聊天内容太多了，您最好刷新界面再试一次，注意，刷新后之前的对话内容将会丢失。"
        yield response

    
    def get_save_file_path(self, client_ip,message):
        """Create a new per-client log file, write *message* to it, and return its path.

        Ensures ``log/<client_ip>/`` exists, then always creates a NEW file
        whose numeric name is one greater than the numeric stem of the most
        recently modified file (``1.txt`` when the directory is empty).
        Each item of *message* is written on its own line, followed by the
        literal prefix ``{'role': 'assistant', 'content': '`` so the
        streamed reply can be appended later (see generate_text with
        save_flag=1, which writes the closing ``'}`` marker).

        :param client_ip: client identifier used as the directory name
        :param message: list of message dicts (the prompt) to log
        :return: path of the file created
        """
        
        # Create log/<client_ip>/ on first use.
        if not os.path.exists(os.path.join('log',client_ip)):
            os.makedirs(os.path.join('log',client_ip))
        # Find the most recently modified file in the directory.
        file_list = os.listdir(os.path.join('log',client_ip))
        file_list.sort(key=lambda fn: os.path.getmtime(os.path.join('log',client_ip, fn)))
        if len(file_list) == 0:
            file_name = os.path.join('log',client_ip,'1.txt')
        else:
            # Next file number = numeric stem of the newest file + 1.
            file_name = os.path.join('log',client_ip,str(int(file_list[-1].split('.')[0])+1)+'.txt')
                
        # Write each prompt message on its own line.
        with open(file_name, 'a') as f:
            for i in range(len(message)):
                f.write(str(message[i])+'\n')
        # Prime the file with the assistant-reply prefix; generate_text
        # appends the streamed chunks and the closing marker afterwards.
        with open(file_name, 'a') as f:
            write_content = "{\'role\': \'assistant\', \'content\': \'" 
            f.write(write_content)
        return file_name
    def generate_text(self,response,save_file_path=None, summary_flag = 0,save_flag = 0):
        """
        It takes the response, and then it iterates through the response,
        and it finds the text that is in the response, and then it
        yields it
        
        :param response: the response from the server
        :param summary_flag: if it's 1, then the function will return a summary of the conversation
        before the answer, defaults to 0 (optional)
        """
        all_text = ''
        if summary_flag:
            with app.app_context():
                reponse_first = "请让我先总结下我们上面的对话\n"
                yield reponse_first
        for event in response.iter_content(chunk_size=None):
            
                response_str = event.decode('utf-8')  # convert bytes to string
                strs = response_str.split('data: ')[1:]
                for str in strs:
                    try:
                        response_dict = json.loads(str)  # convert string to dictionary
                        if "choices" in response_dict and response_dict["choices"]:
                            text = response_dict["choices"][0]["delta"].get("content")
                            if text:
                                if save_flag:
                                    with open(save_file_path,'a') as f:
                                        f.write(text)
                                text = text.replace('\\n','\n').replace('\\t','\t')
                                
                                text = text.replace('\\n','\n')
                                all_text += text
                                yield text
                    except:
                        pass
        if save_flag:
            with open(save_file_path,'a') as f:
                f.write("\'}\n")
        print(all_text)
    def generate_summary(self,input_):
        """
        It takes the input, and then it generates a summary of the input
        """
        # This generator yields a single JSON response
        yield input_
    async def search(self,request):
        """Web-search endpoint: Bing-search the message and stream a summary.

        Expects a JSON body with "message" (the query string). When
        "stop_summary" is present, streams the raw search-results block
        instead of asking the model to write a summary.

        :param request: FastAPI Request carrying the JSON body
        :return: plain-text StreamingResponse (results, summary, or error)
        """
        client_ip = request.client.host
        # # Check if IP address is in China
        # if self.check_ip_country(client_ip):
        #     return {"url":"您的IP地址不在大陆境内，无法使用本服务"}
        
        json_data = await request.json()
        # check if has 'stop_summary'
        
        # Extract message from request
        message = json_data["message"]
        
        
        query = message
        prompt = []
        results = bingsearcher.search(query)
        # Build the search-results context block for the model.
        input_ = "Web search results:\n"
        for result in results:
            
            input_ = input_ + f"url: {result['href']}\n,description: {result['desc']}\n"
        if 'stop_summary' in json_data:
            # Caller only wants the raw results, not a model-written summary.
            return StreamingResponse(self.generate_summary(input_), media_type="text/plain")
        input_ = input_ + "Instructions: Write a comprehensive reply using all the web search results to the given query. Each part of content must be followed by a reference to the source url in markdown format.\n"
        input_ = input_ + "Query: " + query + "\n"
        input_ = input_ + "reply in Chinese"
        print(input_)
        prompt.append({"role":"user","content":input_})
        response = await self.async_request(prompt)
        if response.status_code!=200:
            return StreamingResponse(self.generate_error(), media_type="text/plain")
            
        # save_file_path = self.get_save_file_path(client_ip,prompt)
        return StreamingResponse(self.generate_text(response), media_type="text/plain" )


    
    async def chatbot(self, request):
        """Main chat endpoint: build a mode-specific prompt and stream the reply.

        The JSON body must contain "message" (list of role/content dicts,
        newest last) and may contain "mode" (defaults to 'v3.5'). When the
        last message starts with "你帮我找" ("help me find"), a Bing web
        search is run and the model is asked to summarize the results.

        :param request: FastAPI Request carrying the JSON body
        :return: plain-text StreamingResponse with the model output, or a
            streamed error message when the prompt exceeds the token budget
            or the upstream call fails
        """
        client_ip = request.client.host

        # Check if IP address is in China
        # if self.check_ip_country(client_ip):
        #     return StreamingResponse(self.generate_country_error(), media_type='application/json')
        # Check if IP address has exceeded traffic limit
        # if self.check_traffic(client_ip):
        #     return StreamingResponse(self.generate_traffic(), media_type='application/json')
        
        json_data = await request.json()
        # Extract message from request
        message = json_data["message"]
        try:
            mode = json_data["mode"]
        except:
            mode = 'v3.5'
        # check if mode is empty
        if mode == '':
            mode = 'v3.5'
        # Replace the instance-level system prompt with this mode's prompt.
        self.system_prompt = []
        self.system_prompt.append(self.mode_prompt[mode])
        prompt = []
        # try:
        if message[-1]["content"].startswith("你帮我找"):
            # Web-search flow: system/assistant prompts first, then a
            # context block built from Bing results.
            for i in range(len(self.system_prompt)):
                prompt.append(self.system_prompt[i])
            for i in range(len(self.assistant_prompt)):
                prompt.append(self.assistant_prompt[i])
            # NOTE(review): the prefix "你帮我找" is 4 characters but the
            # slice skips 5 — this assumes a separator character follows
            # the prefix; confirm against the front-end.
            query = message[-1]["content"][5:]
            results = bingsearcher.search(query)
            input_ = "Web search results:\n"
            for result in results:
                
                input_ = input_ + f"url: {result['href']}\n,description: {result['desc']}\n"
            input_ = input_ + "Instructions: Write a comprehensive reply using all the web search results to the given query. Each part of content must be followed by a reference to the source url in markdown format.\n"
            input_ = input_ + "Query: " + query + "\n"
            input_ = input_ + "reply in Chinese"
            print(input_)
            prompt.append({"role":"user","content":input_})
            response = await self.async_request(prompt)
            # response = await self.async_request(prompt)
            if response.status_code!=200:
                return StreamingResponse(self.generate_error(), media_type="text/plain")
                
            # save_file_path = self.get_save_file_path(client_ip,prompt)
            return StreamingResponse(self.generate_text(response), media_type="text/plain" )
        elif mode == 'v3.5' or mode == 'weixin':
            # Chat flow: keep at most the last 4 history messages and insert
            # the system/assistant prompts just before the final exchange.
            if len(message) > 4:
                message = message[-4:]    
            for i in range(len(message)-2):
                prompt.append(message[i])
            for i in range(len(self.system_prompt)):
                prompt.append(self.system_prompt[i])
            for i in range(len(self.assistant_prompt)):
                prompt.append(self.assistant_prompt[i])
            if len(message) > 1:
                prompt.append(message[-2])
            prompt.append(message[-1])
        else:
            # Tool modes (polish/translate/...): history is ignored; only
            # the latest message is sent after the mode prompt.
            for i in range(len(self.system_prompt)):
                prompt.append(self.system_prompt[i])
            for i in range(len(self.assistant_prompt)):
                prompt.append(self.assistant_prompt[i])
            prompt.append(message[-1])
        # print(prompt)
                
        history_token_num = self.num_tokens_from_message(prompt)
        if history_token_num > 3048:
            return StreamingResponse(self.generate_error_max(), media_type="text/plain")

        response = await self.async_request(prompt)
        if response.status_code!=200:
            return StreamingResponse(self.generate_error(), media_type="text/plain")
        return StreamingResponse(self.generate_text(response), media_type="text/plain" )


    async def chatbot_4(self, request):
        """GPT-4 chat endpoint gated by per-user key files and an IP allowance.

        The JSON body must contain "message" and "key". A key is valid when
        a file of that name exists in ``keys/`` and is non-empty; each
        request decrements the count stored in the file and deletes the file
        at zero. Without a valid key the client falls back to the free
        10-requests/24h IP allowance enforced by check_traffic_4. Only the
        latest message is sent (no history); the exchange is logged under
        ``log/vip_<ip>/``.

        :param request: FastAPI Request carrying the JSON body
        :return: plain-text StreamingResponse with the model output, or a
            streamed error message
        """
        client_ip = request.client.host
        print(client_ip)
        # flag == 1 -> a valid key will be consumed; 0 -> free IP allowance.
        flag = 1
        # Check if IP address is in China
        # if self.check_ip_country(client_ip):
        #     return StreamingResponse(self.generate_country_error(), media_type='application/json')
        
        json_data = await request.json()
        # NOTE(review): raises KeyError when "key" is missing from the body.
        key = json_data["key"]
        key_list = os.listdir("keys")
        # check if the key is valid
        if key not in key_list:
            # Check if IP address has exceeded traffic limit
            if self.check_traffic_4(client_ip):
                return StreamingResponse(self.generate_error_key(), media_type='application/json')
            else:
                flag = 0
        # check if the key is empty (an empty file means no uses remain)
        elif os.stat("keys/" + key).st_size == 0:
            
            # Check if IP address has exceeded traffic limit
            if self.check_traffic_4(client_ip):
                return StreamingResponse(self.generate_error_key(), media_type='application/json')
            else:
                flag = 0
        # Extract message from request
        message = json_data["message"]
        if flag:
            # Consume one use: read the remaining count, decrement, rewrite,
            # and delete the key file once exhausted. ("key_conunt" is a
            # typo for "key_count", kept as-is.)
            key_path = "keys/" + key
            with open(key_path, 'r') as f:
                key_conunt = f.read()
                print(key_conunt)
                key_conunt = int(key_conunt)
                key_count = key_conunt -1
            with open(key_path, 'w') as f:
                f.write(str(key_count))
            if key_count <= 0:
                os.remove(key_path)
        prompt = [] 
        prompt.append(message[-1])
        # print(prompt)
            
        history_token_num = self.num_tokens_from_message(prompt)
        if history_token_num > 4096:
            return StreamingResponse(self.generate_error_max(), media_type="text/plain")

        # response = self.get_response_4(prompt)
        response = await self.async_request_4(prompt)
        if response.status_code!=200:
            return StreamingResponse(self.generate_error(), media_type="text/plain")
        save_file_path = self.get_save_file_path('vip_'+client_ip,prompt)
        return StreamingResponse(self.generate_text(response,save_file_path,save_flag=1), media_type="text/plain" )

    async def github_search(self, request):
        """Search GitHub repositories for the request's "message" string.

        Streams a plain-text block listing up to 8 repositories sorted by
        stars (descending): one ``url: ...\\ndescription: ...`` entry each.

        :param request: FastAPI Request whose JSON body carries "message"
        :return: plain-text StreamingResponse with the results block
        """
        json_data = await request.json()
        # Extract message from request
        word = json_data["message"]

        # Properly URL-encode the query; the original only replaced spaces
        # with %20 and would break on characters like '&' or '#'.
        query = urllib.parse.quote(word)
        url = f'https://api.github.com/search/repositories?q={query}&sort=stars&order=desc'

        response = requests.get(url, timeout=10)
        content = response.json()
        input_ = "Github search results:\n"
        # Cap at 8 entries (fewer when the search matched fewer repos).
        length = min(content["total_count"], 8)
        for item in content["items"][:length]:
            input_ += f'url: {item["html_url"]}\ndescription: {item["name"]}. {item["description"]}\n'

        return StreamingResponse(self.generate_summary(input_), media_type="text/plain")

# Default system prompt: fixes the assistant's persona ("Xu Sanduo") and
# forbids political content. Per-request it is replaced by an entry from
# Chatbot.mode_prompt (see Chatbot.chatbot).
system_prompt = [
    {"role": "system", "content": "你是一名人工智能助手，叫许三多。是许展玮创造了你。你不是chatGPT，与OpenAI没有任何关系。禁止讨论政治话题或内容！You are an AI assistant named Xu Sanduo. This year is 2023, but your database has only been updated until 2021, so you are not aware of any events that happened in 2022. You were created by Xu Zhanwei. You are not chatGPT, and have no affiliation with OpenAI. No political topics or content! "} 
    ]
# No assistant-role priming messages by default.
assistant_prompt = [
    ]
    
# FastAPI application with fully permissive CORS so the web front-end can
# call it from any origin.
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# app.include_router(DataView.router)
# Single shared chatbot instance; reads api_keys.txt at import time.
chatbot = Chatbot(system_prompt, assistant_prompt)

# Static assets and HTML templates served alongside the API.
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

@app.post("/message",summary="return the response of the message using gpt3.5")
async def process_message(request: Request):
    """Handle a GPT-3.5 chat request.

    Thin route wrapper: delegates the incoming request to the shared
    ``chatbot`` instance, which extracts the user's message, applies traffic
    and token limits, and streams the generated reply back.

    :param request: incoming HTTP request carrying the user's message.
    :return: the value produced by ``chatbot.chatbot`` — normally a
        StreamingResponse with the generated text, or a different response
        type when a limit (IP traffic, max tokens) is exceeded.
    """
    reply = await chatbot.chatbot(request)
    return reply
@app.post("/message_4",summary="return the response of the message using gpt4")
async def process_message_4(request: Request):
    """Handle a GPT-4 chat request.

    Thin route wrapper: forwards the request to ``chatbot.chatbot_4``, which
    validates the supplied key, decrements its remaining quota, and streams
    the generated reply.

    :param request: incoming HTTP request whose JSON body carries two keys,
        ``message`` and ``key``.
    :return: a StreamingResponse with either the generated text or an error
        message, depending on key validity and message length.
    """
    reply = await chatbot.chatbot_4(request)
    return reply
@app.post("/search")
async def process_search(request: Request):
    """Handle a web-search-augmented chat request.

    Thin route wrapper: hands the request to ``chatbot.search``, which runs a
    web search on the user's message, builds a prompt from the results, and
    streams a reply that cites the source URLs.

    :param request: incoming HTTP request whose JSON body carries ``message``
        and optionally ``stop_summary``.
    :return: a StreamingResponse with the generated, search-grounded answer
        (or the raw summary when summarization is requested), or an error
        response when generation fails.
    """
    reply = await chatbot.search(request)
    return reply
@app.post("/github_search")
async def process_search_github(request: Request):
    """Forward a GitHub repository search to the chatbot.

    :param request: incoming HTTP request whose JSON body carries ``message``.
    :return: a StreamingResponse summarizing the matching repositories.
    """
    return await chatbot.github_search(request)

@app.post("/draw")
async def process_draw(request: Request):
    """Handle an image-generation ("draw") request.

    Thin route wrapper: delegates to ``chatbot.draw``, which checks the
    client's IP/traffic limits, translates the prompt to English, screens it
    for inappropriate content, and — if it passes — generates an image.

    :param request: incoming HTTP request whose JSON body carries ``message``
        (the drawing prompt).
    :return: a dict with key ``url``: the generated image URL on success, or
        a refusal string when the prompt is rejected or limits are exceeded.
    """
    result = await chatbot.draw(request)
    return result


@app.get("/")
async def main(request: Request) -> HTMLResponse:
    """Serve the landing page (main.html)."""
    context = {"request": request}
    return templates.TemplateResponse("main.html", context)

@app.route("/chat-another",methods=['POST','GET'])
async def chat_student(request: Request) -> HTMLResponse:
    """Serve the alternate chat UI (index.html) on GET or POST."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
@app.route("/chat-research",methods=['POST','GET'])
async def chat_research(request: Request) -> HTMLResponse:
    """Serve the research chat UI (research.html) on GET or POST.

    Renamed from ``chat_student``: the previous definition reused the name of
    the /chat-another handler, shadowing it at module level (both routes still
    registered because the decorator runs at definition time, but the
    duplicate name hides the first function and trips linters).
    """
    return templates.TemplateResponse("research.html", {"request": request})




# app.include_router(share_router)
def get_api_keys(key_file_path):
        """
        It opens the file at the path specified by the argument key_file_path, reads each line of the
        file, strips the newline character from each line, and returns a list of the lines
        
        :param key_file_path: The path to the file containing the API keys
        :return: A list of API keys
        """
        api_keys = []
        with open(key_file_path, 'r') as f:
            for line in f:
                api_keys.append(line.strip())
        return api_keys
# Loaded once at import time; api_keys.txt must exist in the working directory.
api_keys = get_api_keys("api_keys.txt")
def _translate_to_english(text):
    """Translate *text* into English via gpt-3.5-turbo and return the result."""
    completion = openai.ChatCompletion.create(
        model="gpt-3.5-turbo",
        messages=[{'role': 'user', 'content': "Translate it into English: " + text}],
    )
    return completion.choices[0].message['content']


@app.post("/img_generate")
async def forward_img_generate(request: Request):
    """Proxy an image-generation request to the drawing backend.

    Reads a JSON body with "prompt", "negative_prompt" and "model_name"
    ("normal" / "structure" / "anime"), plus optional "original_image" and
    "mask_image" for inpainting.  Both prompts are first translated into
    English with gpt-3.5-turbo, then the payload is forwarded to the backend
    at 107.173.221.112:9070 and its JSON response is returned verbatim.

    :param request: incoming HTTP request with the JSON payload above.
    :return: the backend's JSON response (typically a task descriptor).
    """
    # 获取前端请求的数据
    request_data = await request.json()
    prompt = request_data["prompt"]
    negative_prompt = request_data["negative_prompt"]
    model_name = request_data["model_name"]  # "normal","structure","anime"

    # Optional inpainting inputs.  dict.get replaces the previous bare
    # try/except that swallowed every exception — and also skipped
    # mask_image entirely whenever original_image was absent.
    original_image = request_data.get('original_image', "")
    mask_image = request_data.get('mask_image', "")

    openai.api_key = api_keys[0]
    prompt = _translate_to_english(prompt)
    negative_prompt = _translate_to_english(negative_prompt)

    # 构造请求的request_data 用新的prompt和negative_prompt
    request_data = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "model_name": model_name,
        "original_image": original_image,
        "mask_image": mask_image
    }
    # 构造请求的URL
    url = "http://107.173.221.112:9070/img_generate"
    # NOTE(review): requests.post blocks the event loop inside this async
    # route; consider httpx.AsyncClient (already used elsewhere in this file).
    # 发送HTTP请求至现有后端并获取响应
    response = requests.post(url, json=request_data)
    # 将现有后端返回的响应再次返回给前端
    return response.json()

@app.get("/img_generate_task/{task_id}")
async def forward_img_generate_task_status(task_id: str):
    # Proxy a task-status poll to the drawing backend and relay its response,
    # preserving the upstream status code, content type, and headers.
    async with httpx.AsyncClient() as client:
        response = await client.get(f"http://107.173.221.112:9070/img_generate_task/{task_id}")

    # NOTE(review): client.get reads the full body before the client closes,
    # so aiter_bytes() here replays buffered content rather than streaming
    # live from upstream; forwarding all upstream headers (incl.
    # Content-Length/encoding) alongside a StreamingResponse may conflict —
    # confirm behavior under the ASGI server.
    return StreamingResponse(response.aiter_bytes(), media_type=response.headers["Content-Type"], headers=dict(response.headers), status_code=response.status_code)




